{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c54a3bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "import copy\n",
    "\n",
    "import numpy as np  # linear algebra\n",
    "import pydicom\n",
    "import os\n",
    "import scipy.ndimage\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage import measure, morphology\n",
    "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
    "import cv2\n",
    "from PIL import Image\n",
    "\n",
    "INPUT_FOLDER = '/nfs3-p1/zsxm/dataset/aorta_CTA/zhaoqifeng/img/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f869762",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the scans in given folder path\n",
    "def load_scan(path):\n",
    "    slices = [pydicom.dcmread(path + '/' + s) for s in os.listdir(path)]\n",
    "    slices.sort(key=lambda x: float(x.InstanceNumber))\n",
    "#     try:\n",
    "#         slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])\n",
    "#     except:\n",
    "#         slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)\n",
    "\n",
    "#     for s in slices:\n",
    "#         s.SliceThickness = slice_thickness\n",
    "\n",
    "    return slices\n",
    "\n",
    "\n",
    "def get_pixels_hu(slices):\n",
    "    image = np.stack([s.pixel_array for s in slices])\n",
    "    # Convert to int16 (from sometimes int16),\n",
    "    # should be possible as values should always be low enough (<32k)\n",
    "    image = image.astype(np.int16)\n",
    "\n",
    "    # Set outside-of-scan pixels to 0\n",
    "    # The intercept is usually -1024, so air is approximately 0\n",
    "    image[image == -2000] = 0\n",
    "\n",
    "    # Convert to Hounsfield units (HU)\n",
    "    for slice_number in range(len(slices)):\n",
    "        intercept = slices[slice_number].RescaleIntercept\n",
    "        slope = slices[slice_number].RescaleSlope\n",
    "\n",
    "        if slope != 1:\n",
    "            image[slice_number] = slope * image[slice_number].astype(np.float64)\n",
    "            image[slice_number] = image[slice_number].astype(np.int16)\n",
    "\n",
    "        image[slice_number] += np.int16(intercept)\n",
    "\n",
    "    return np.array(image, dtype=np.int16)\n",
    "\n",
    "\n",
    "# def set_window(image, w_center, w_width):\n",
    "#     image_copy = image.copy()\n",
    "#     for slice_number in range(len(image)):\n",
    "#         image_copy[slice_number] = cv2.GaussianBlur(image_copy[slice_number], (3,3), 1)\n",
    "#         image_copy[slice_number][image_copy[slice_number]>w_center+int(w_width/2)] = np.int16(w_center-int(w_width/2))\n",
    "#         image_copy[slice_number][image_copy[slice_number]<w_center-int(w_width/2)] = np.int16(w_center-int(w_width/2))\n",
    "\n",
    "#     return image_copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c06fe650",
   "metadata": {},
   "outputs": [],
   "source": [
    "patient = load_scan(INPUT_FOLDER)\n",
    "patient_pixels = get_pixels_hu(patient)\n",
    "\n",
    "plt.hist(patient_pixels.flatten(), bins=80, color='c')\n",
    "plt.xlabel(\"Hounsfield Units (HU)\")\n",
    "plt.ylabel(\"Frequency\")\n",
    "plt.show()\n",
    "\n",
    "# Show some slice in the middle\n",
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(patient_pixels[80], cmap=plt.cm.gray)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f88078e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_window(image, w_center, w_width):\n",
    "    image_copy = image.copy()\n",
    "    for slice_number in range(len(image)):\n",
    "        image_copy[slice_number] = np.int16(cv2.bilateralFilter(image_copy[slice_number].astype(np.float32), 5, 150, 1))\n",
    "        image_copy[slice_number][image_copy[slice_number]>w_center+int(w_width/2)] = np.int16(w_center-int(w_width/2))\n",
    "        image_copy[slice_number][image_copy[slice_number]<w_center-int(w_width/2)] = np.int16(w_center-int(w_width/2))\n",
    "        \n",
    "#         kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3))\n",
    "#         image_copy[slice_number] = cv2.erode(image_copy[slice_number], kernel, iterations=2)\n",
    "#         image_copy[slice_number] = cv2.dilate(image_copy[slice_number], kernel, iterations=2)\n",
    "        zero_image = np.zeros_like(image_copy[slice_number])\n",
    "        zero_image[image_copy[slice_number] > 250] = 255\n",
    "        image_copy[slice_number] = zero_image\n",
    "        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5,5))\n",
    "        image_copy[slice_number] = cv2.erode(image_copy[slice_number], kernel, iterations=1)\n",
    "        image_copy[slice_number] = cv2.dilate(image_copy[slice_number], kernel, iterations=1)\n",
    "    return image_copy.astype(np.uint8)\n",
    "\n",
    "patient_copy = set_window(patient_pixels, 300, 600)\n",
    "# plt.figure(figsize=(10,10))\n",
    "plt.imshow(patient_copy[80, 200:350, 175:325], cmap='Greys')\n",
    "plt.show()\n",
    "# for i in range(len(patient_copy)):\n",
    "#     plt.imshow(patient_copy[i, 200:350, 175:325], cmap=plt.cm.gray)\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13aba51d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Elem():\n",
    "    def __init__(self, key, contour):\n",
    "        self.root = key\n",
    "        self.end = key\n",
    "        self.contours = [contour]\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.contours)\n",
    "    \n",
    "    def append(self, key, contour):\n",
    "        self.end = key\n",
    "        self.contours.append(contour)\n",
    "        \n",
    "    def get_bboxes(self):\n",
    "        self.bboxes = [cv2.boundingRect(c) for c in self.contours]\n",
    "        def process(bbox):\n",
    "            x, y, w, h = bbox\n",
    "            cx, cy = x+w//2, y+h//2\n",
    "            mb = max(w, h)\n",
    "            sx, sy = cx - mb, cy - mb\n",
    "            ex, ey = cx + mb, cy + mb\n",
    "            return sx, sy, ex, ey\n",
    "            \n",
    "        self.bboxes = [process(b) for b in self.bboxes]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a4b21b",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_x = int(patient_copy.shape[2]/2-patient_copy.shape[2]*0.15)\n",
    "end_x = int(patient_copy.shape[2]/2+patient_copy.shape[2]*0.15)\n",
    "start_y = int(patient_copy.shape[1]*0.55-patient_copy.shape[1]*0.15)\n",
    "end_y = int(patient_copy.shape[1]*0.55+patient_copy.shape[1]*0.15)\n",
    "patient_cut = patient_copy[:, start_y:end_y, start_x:end_x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "106425b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(patient_cut.shape)\n",
    "print(patient_cut.dtype)\n",
    "for i in range(len(patient_cut)):\n",
    "    plt.title(str(i))\n",
    "    plt.imshow(patient_cut[i], cmap=\"gray\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0c098d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_intersection(origin, first, second):\n",
    "    zero1, zero2 = np.zeros_like(origin), np.zeros_like(origin)\n",
    "    cv2.fillPoly(zero1, [first], 125)\n",
    "    cv2.fillPoly(zero2, [second], 130)\n",
    "    inter = zero1 + zero2\n",
    "    inter[inter<255] = 0\n",
    "    contours, _ = cv2.findContours(inter, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
    "    inter_area = 0\n",
    "    for contour in contours:\n",
    "        inter_area += cv2.contourArea(contour)\n",
    "    second_area = cv2.contourArea(second)\n",
    "    if second_area == 0:\n",
    "        return 0\n",
    "    assert inter_area <= second_area\n",
    "    return inter_area/second_area"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c6d5f02",
   "metadata": {},
   "outputs": [],
   "source": [
    "pre_contours, _ = cv2.findContours(patient_cut[0], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
    "pre_circle = list(map(lambda x: cv2.minEnclosingCircle(x), pre_contours))\n",
    "path_dict = {}\n",
    "for i in range(1, len(patient_cut)):\n",
    "    cur_contours, _ = cv2.findContours(patient_cut[i], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n",
    "    cur_circle = list(map(lambda x: cv2.minEnclosingCircle(x), cur_contours))\n",
    "    for j in range(len(pre_circle)):\n",
    "        candidate_list = []\n",
    "        for k in range(len(cur_circle)):\n",
    "            dis = sqrt((pre_circle[j][0][0]-cur_circle[k][0][0])**2+(pre_circle[j][0][1]-cur_circle[k][0][1])**2)\n",
    "            max_r, min_r = max(pre_circle[j][1], cur_circle[k][1]), min(pre_circle[j][1], cur_circle[k][1])\n",
    "            if dis <= max_r - 0.5 * min_r:\n",
    "                candidate_list.append((k, dis, max_r-min_r))\n",
    "        \n",
    "        if len(candidate_list) == 0:\n",
    "            continue\n",
    "        candidate_list.sort(key=lambda x:x[1:3])\n",
    "        if not (i-1, j) in path_dict:\n",
    "            path_dict[(i-1, j)] = Elem((i-1, j), pre_contours[j])\n",
    "        path = path_dict.pop((i-1, j))\n",
    "        temp_list = [copy.deepcopy(path) for s in range(len(candidate_list))]\n",
    "        for s, candidate in enumerate(candidate_list):\n",
    "            k = candidate[0]\n",
    "            if get_intersection(patient_cut[i], pre_contours[j], cur_contours[k]) < 0.8:\n",
    "                continue\n",
    "            temp_list[s].append((i, k), cur_contours[k])\n",
    "            path_dict[(i, k)] = temp_list[s]\n",
    "#         k = candidate_list[0][0]\n",
    "#         path.append((i, k), cur_contours[k])\n",
    "#         path_dict[(i, k)] = path\n",
    "    \n",
    "    pre_contours = cur_contours\n",
    "    pre_circle = cur_circle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dafd1b85",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(path_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa289f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_len = -1\n",
    "for val in path_dict.values():\n",
    "    this_len = len(val)\n",
    "    if this_len > max_len:\n",
    "        max_len = this_len\n",
    "        \n",
    "print(max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3f00b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "path_list = list(path_dict.values())\n",
    "path_list.sort(key=lambda x: len(x), reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1557ff67",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(path_list[0]))\n",
    "print(len(path_list[1]))\n",
    "print(len(path_list[2]))\n",
    "print(len(path_list[3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b2324c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = path_list[0]\n",
    "canvas = np.zeros_like(patient_cut)\n",
    "start = path.root[0]\n",
    "end = path.end[0]+1\n",
    "for i in range(start, end):\n",
    "    cv2.fillPoly(canvas[i], [path.contours[i-start]], 255)\n",
    "    \n",
    "for i in range(len(canvas)):\n",
    "    plt.title(str(i))\n",
    "    plt.imshow(canvas[i], cmap=\"gray\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0710c71a",
   "metadata": {},
   "outputs": [],
   "source": [
    "path.get_bboxes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dccacab",
   "metadata": {},
   "outputs": [],
   "source": [
    "canvas = patient_cut.copy()\n",
    "start = path.root[0]\n",
    "end = path.end[0]+1\n",
    "for i in range(start, end):\n",
    "    cv2.rectangle(canvas[i], path.bboxes[i-start][0:2], path.bboxes[i-start][2:4], 255)\n",
    "    \n",
    "for i in range(len(canvas)):\n",
    "    plt.title(str(i))\n",
    "    plt.imshow(canvas[i], cmap=\"gray\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea68e6e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_window2(image, w_center, w_width):\n",
    "    image_copy = image.copy().astype(np.float32)\n",
    "    for slice_number in range(len(image_copy)):\n",
    "        image_copy[slice_number] = np.clip(image_copy[slice_number], w_center-int(w_width/2), w_center+int(w_width/2))\n",
    "        image_copy[slice_number] = (image_copy[slice_number]-image_copy[slice_number].min())/(image_copy[slice_number].max()-image_copy[slice_number].min())\n",
    "    \n",
    "    return image_copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36bb71a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_img = set_window2(patient_pixels, 40, 400)\n",
    "\n",
    "for i in range(start, end):\n",
    "    x1, y1 = path.bboxes[i-start][0] + start_x, path.bboxes[i-start][1] + start_y\n",
    "    x2, y2 = path.bboxes[i-start][2] + start_x, path.bboxes[i-start][3] + start_y\n",
    "    cv2.rectangle(patient_img[i], (x1, y1), (x2, y2), 1, 2)\n",
    "    \n",
    "for i in range(len(patient_img)):\n",
    "    plt.title(str(i))\n",
    "    plt.imshow(patient_img[i], cmap=\"gray\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f4c5ad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_bbox(slices, path, start_x, start_y, save_path):\n",
    "    if not os.path.exists(save_path):\n",
    "        os.mkdir(save_path)\n",
    "    for i, s in enumerate(slices):\n",
    "        \n",
    "        img = s.pixel_array\n",
    "        if path.root[0] <= i <= path.end[0]:\n",
    "            x1, y1 = path.bboxes[i-start][0] + start_x, path.bboxes[i-start][1] + start_y\n",
    "            x2, y2 = path.bboxes[i-start][2] + start_x, path.bboxes[i-start][3] + start_y\n",
    "            cv2.rectangle(img, (x1, y1), (x2, y2), 3000, 1)\n",
    "        s.PixelData = pydicom.encaps.encapsulate([img.tobytes()])\n",
    "        s.file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1'\n",
    "        s.save_as(os.path.join(save_path, f'{i}.dcm'), write_like_original=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0380d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "patient = load_scan(INPUT_FOLDER)\n",
    "draw_bbox(patient, path, start_x, start_y, '/nfs3-p1/zsxm/dataset/aorta_CTA/zhaoqifeng/save/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90cd09ae",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
